import torch


def qdot(A, B, p):
    AB = torch.mul(A, B)
    y = -torch.sum(AB[p:]) + torch.sum(AB[:p])
    return y


def rhob_a(a, b, p):
    norma_p = torch.linalg.norm(a[0:p], ord=2)
    normb_p = torch.linalg.norm(b[0:p], ord=2)
    bb = torch.cat([a[:p], b[p:].squeeze() * norma_p / normb_p],dim=0)
    return bb


def dist_UltraE(eh, R, et, p, beta):
    distU13 = dist13(torch.mm(R, eh[:, None]), et, p, beta)
    distU14 = dist14(torch.mm(R, eh[:, None]), et, p, beta)
    return distU13, distU14

def simple_dist_UltraE(eh, R, et, p, beta):
    distU = Sdist(torch.mm(R, eh[:, None]), et, beta, p)
    return distU

def dist13(x, y, p, beta):
    distU = Sdist(y, rhob_a(y, x, p), beta, p) + Hdist(rhob_a(y, x, p), x, beta, p)
    return distU


def dist14(x, y, p, beta):
    distU = Sdist(x, rhob_a(x, y, p), beta, p) + Hdist(rhob_a(x, y, p), y, beta, p)
    return distU


def Sdist(A, B, beta, p):
    beta = torch.tensor(beta)
    temp = qdot(A, B, p) / beta
    if abs(temp) < 1:
        y = torch.sqrt(abs(beta)) * torch.acos(abs(temp))
    else:
        y = torch.sqrt(abs(beta)) * torch.acosh(abs(temp))
    return y


def Hdist(A, B, beta, p):
    beta = torch.tensor(beta)
    temp = qdot(A, B, p) / beta
    if abs(temp) < 1:
        y = torch.sqrt(abs(beta)) * torch.acos(abs(temp))
    else:
        y = torch.sqrt(abs(beta)) * torch.acosh(abs(temp))
    return y
